{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "UEBilEjLj5wY"
   },
   "source": [
    "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n",
    "- Author: Sebastian Raschka\n",
    "- GitHub Repository: https://github.com/rasbt/deeplearning-models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 119
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 536,
     "status": "ok",
     "timestamp": 1524974472601,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "GOzuY8Yvj5wb",
    "outputId": "c19362ce-f87a-4cc2-84cc-8d7b4b9e6007"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sebastian Raschka \n",
      "\n",
      "CPython 3.7.3\n",
      "IPython 7.6.1\n",
      "\n",
      "torch 1.2.0\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark -a 'Sebastian Raschka' -v -p torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "rH4XmErYj5wm"
   },
   "source": [
    "# LeNet-5 CIFAR10 Classifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook implements the classic LeNet-5 convolutional network [1] and applies it to the CIFAR10 object classification dataset. The basic architecture is shown in the figure below:\n",
    "\n",
    "![](../images/lenet/lenet-5_1.jpg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "\n",
    "LeNet-5 is commonly regarded as the pioneer of convolutional neural networks, consisting of a very simple architecture (by modern standards). In total, LeNet-5 consists of only 7 layers. 3 out of these 7 layers are convolutional layers (C1, C3, C5), which are connected by two average pooling layers (S2 & S4). The penultimate layer is a fully connexted layer (F6), which is followed by the final output layer. The additional details are summarized below:\n",
    "\n",
    "- All convolutional layers use 5x5 kernels with stride 1.\n",
    "- The two average pooling (subsampling) layers are 2x2 pixels wide with stride 1.\n",
    "- Throughrout the network, tanh sigmoid activation functions are used. (**In this notebook, we replace these with ReLU activations**)\n",
    "- The output layer uses 10 custom Euclidean Radial Basis Function neurons for the output layer. (**In this notebook, we replace these with softmax activations**)\n",
    "\n",
    "**Please note that the original architecture was applied to MNIST-like grayscale images (1 color channel). CIFAR10 has 3 color-channels. I found that using the regular architecture results in very poor performance on CIFAR10 (approx. 50% ACC). Hence, I am multiplying the number of kernels by a factor of 3 (according to the 3 color channels) in each layer, which improves is a little bit (approx. 60% Acc).**\n",
    "\n",
    "### References\n",
    "\n",
    "- [1] Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. Gradient-based learning applied to document recognition. Proceedings of the IEEE, november 1998."
   ]
  },
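  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough check of the layer sizes in the original architecture, the spatial output width of a convolution or pooling layer can be worked out with the standard formula below. The symbols are introduced here for illustration only: $n$ is the input width, $k$ the kernel width, $s$ the stride, and $p$ the padding. The sketch assumes a 32x32 input, no padding ($p = 0$), 5x5 convolutions with $s = 1$, and 2x2 subsampling with $s = 2$.\n",
    "\n",
    "$$o = \\left\\lfloor \\frac{n + 2p - k}{s} \\right\\rfloor + 1$$\n",
    "\n",
    "$$\\text{C1}: 32 \\to 28, \\quad \\text{S2}: 28 \\to 14, \\quad \\text{C3}: 14 \\to 10, \\quad \\text{S4}: 10 \\to 5, \\quad \\text{C5}: 5 \\to 1$$\n",
    "\n",
    "Under these assumptions, C5 sees 5x5 feature maps and produces 1x1 outputs, so it can equivalently be written as a fully connected layer."
   ]
  },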
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "MkoGLH_Tj5wn"
   },
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     }
    },
    "colab_type": "code",
    "id": "ORj09gnrj5wp"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "from torchvision import datasets\n",
    "from torchvision import transforms\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "I6hghKPxj5w0"
   },
   "source": [
    "## Model Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 85
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 23936,
     "status": "ok",
     "timestamp": 1524974497505,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "NnT0sZIwj5wu",
    "outputId": "55aed925-d17e-4c6a-8c71-0d9b3bde5637"
   },
   "outputs": [],
   "source": [
    "##########################\n",
    "### SETTINGS\n",
    "##########################\n",
    "\n",
    "# Hyperparameters\n",
    "RANDOM_SEED = 1\n",
    "LEARNING_RATE = 0.001\n",
    "BATCH_SIZE = 128\n",
    "NUM_EPOCHS = 10\n",
    "\n",
    "# Architecture\n",
    "NUM_FEATURES = 32*32\n",
    "NUM_CLASSES = 10\n",
    "\n",
    "# Other\n",
    "DEVICE = \"cuda:0\"\n",
    "GRAYSCALE = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### MNIST Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Image batch dimensions: torch.Size([128, 3, 32, 32])\n",
      "Image label dimensions: torch.Size([128])\n",
      "Image batch dimensions: torch.Size([128, 3, 32, 32])\n",
      "Image label dimensions: torch.Size([128])\n"
     ]
    }
   ],
   "source": [
    "##########################\n",
    "### CIFAR-10 Dataset\n",
    "##########################\n",
    "\n",
    "\n",
    "# Note transforms.ToTensor() scales input images\n",
    "# to 0-1 range\n",
    "train_dataset = datasets.CIFAR10(root='data', \n",
    "                                 train=True, \n",
    "                                 transform=transforms.ToTensor(),\n",
    "                                 download=True)\n",
    "\n",
    "test_dataset = datasets.CIFAR10(root='data', \n",
    "                                train=False, \n",
    "                                transform=transforms.ToTensor())\n",
    "\n",
    "\n",
    "train_loader = DataLoader(dataset=train_dataset, \n",
    "                          batch_size=BATCH_SIZE, \n",
    "                          num_workers=8,\n",
    "                          shuffle=True)\n",
    "\n",
    "test_loader = DataLoader(dataset=test_dataset, \n",
    "                         batch_size=BATCH_SIZE,\n",
    "                         num_workers=8,\n",
    "                         shuffle=False)\n",
    "\n",
    "# Checking the dataset\n",
    "for images, labels in train_loader:  \n",
    "    print('Image batch dimensions:', images.shape)\n",
    "    print('Image label dimensions:', labels.shape)\n",
    "    break\n",
    "\n",
    "# Checking the dataset\n",
    "for images, labels in train_loader:  \n",
    "    print('Image batch dimensions:', images.shape)\n",
    "    print('Image label dimensions:', labels.shape)\n",
    "    break"
   ]
  },
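  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The number of mini-batches per epoch follows from the dataset and batch sizes. The sketch below is only an illustration: it assumes the standard CIFAR-10 split of 50,000 training and 10,000 test images and the `DataLoader` default of keeping the last, smaller batch (`drop_last=False`). `batch_stats` is a hypothetical helper, not part of PyTorch.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "# Assumed CIFAR-10 split sizes (standard split: 50,000 train / 10,000 test)\n",
    "NUM_TRAIN = 50_000\n",
    "NUM_TEST = 10_000\n",
    "BATCH_SIZE = 128\n",
    "\n",
    "\n",
    "def batch_stats(num_examples, batch_size):\n",
    "    # Hypothetical helper: number of batches and size of the last batch\n",
    "    # when the final, incomplete batch is kept\n",
    "    num_batches = math.ceil(num_examples / batch_size)\n",
    "    last_batch = num_examples - (num_batches - 1) * batch_size\n",
    "    return num_batches, last_batch\n",
    "\n",
    "\n",
    "print(batch_stats(NUM_TRAIN, BATCH_SIZE))  # (391, 80)\n",
    "print(batch_stats(NUM_TEST, BATCH_SIZE))   # (79, 16)\n",
    "```\n",
    "\n",
    "The 391 training batches agree with the `Batch .../0391` counter printed in the training log of this notebook."
   ]
  },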
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 1 | Batch index: 0 | Batch size: 128\n",
      "Epoch: 2 | Batch index: 0 | Batch size: 128\n"
     ]
    }
   ],
   "source": [
    "device = torch.device(DEVICE)\n",
    "torch.manual_seed(0)\n",
    "\n",
    "for epoch in range(2):\n",
    "\n",
    "    for batch_idx, (x, y) in enumerate(train_loader):\n",
    "        \n",
    "        print('Epoch:', epoch+1, end='')\n",
    "        print(' | Batch index:', batch_idx, end='')\n",
    "        print(' | Batch size:', y.size()[0])\n",
    "        \n",
    "        x = x.to(device)\n",
    "        y = y.to(device)\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "##########################\n",
    "### MODEL\n",
    "##########################\n",
    "\n",
    "\n",
    "class LeNet5(nn.Module):\n",
    "\n",
    "    def __init__(self, num_classes, grayscale=False):\n",
    "        super(LeNet5, self).__init__()\n",
    "        \n",
    "        self.grayscale = grayscale\n",
    "        self.num_classes = num_classes\n",
    "\n",
    "        if self.grayscale:\n",
    "            in_channels = 1\n",
    "        else:\n",
    "            in_channels = 3\n",
    "\n",
    "        self.features = nn.Sequential(\n",
    "            \n",
    "            nn.Conv2d(in_channels, 6*in_channels, kernel_size=5),\n",
    "            nn.MaxPool2d(kernel_size=2),\n",
    "            nn.Conv2d(6*in_channels, 16*in_channels, kernel_size=5),\n",
    "            nn.MaxPool2d(kernel_size=2)\n",
    "        )\n",
    "\n",
    "        self.classifier = nn.Sequential(\n",
    "            nn.Linear(16*5*5*in_channels, 120*in_channels),\n",
    "            nn.Linear(120*in_channels, 84*in_channels),\n",
    "            nn.Linear(84*in_channels, num_classes),\n",
    "        )\n",
    "\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.features(x)\n",
    "        x = torch.flatten(x, 1)\n",
    "        logits = self.classifier(x)\n",
    "        probas = F.softmax(logits, dim=1)\n",
    "        return logits, probas"
   ]
  },
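  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see where the `16*5*5*in_channels` input size of the first linear layer comes from, the following sketch traces the feature-map width through the layers defined above and counts the trainable parameters. It is a plain-Python illustration, assuming 32x32 inputs, no padding, and the PyTorch default that `MaxPool2d` uses a stride equal to its kernel size. `conv_out` and `lenet5_summary` are hypothetical helper names.\n",
    "\n",
    "```python\n",
    "def conv_out(size, kernel, stride=1, padding=0):\n",
    "    # Output width of a convolution or pooling window\n",
    "    return (size + 2 * padding - kernel) // stride + 1\n",
    "\n",
    "\n",
    "def lenet5_summary(in_channels=3, image_size=32, num_classes=10):\n",
    "    c1, c2 = 6 * in_channels, 16 * in_channels\n",
    "    f1, f2 = 120 * in_channels, 84 * in_channels\n",
    "\n",
    "    # Width comments below assume the default image_size of 32\n",
    "    size = conv_out(image_size, kernel=5)      # conv 5x5: 32 -> 28\n",
    "    size = conv_out(size, kernel=2, stride=2)  # pool 2x2: 28 -> 14\n",
    "    size = conv_out(size, kernel=5)            # conv 5x5: 14 -> 10\n",
    "    size = conv_out(size, kernel=2, stride=2)  # pool 2x2: 10 -> 5\n",
    "    flat = c2 * size * size                    # flattened feature vector\n",
    "\n",
    "    # Weights plus biases of each layer\n",
    "    params = [\n",
    "        (5 * 5 * in_channels + 1) * c1,\n",
    "        (5 * 5 * c1 + 1) * c2,\n",
    "        (flat + 1) * f1,\n",
    "        (f1 + 1) * f2,\n",
    "        (f2 + 1) * num_classes,\n",
    "    ]\n",
    "    return size, flat, sum(params)\n",
    "\n",
    "\n",
    "print(lenet5_summary(in_channels=3))  # (5, 1200, 548878)\n",
    "print(lenet5_summary(in_channels=1))  # (5, 400, 61706)\n",
    "```\n",
    "\n",
    "With `in_channels=3` the flattened size is 16*3*5*5 = 1200, which matches `nn.Linear(16*5*5*in_channels, ...)`. Tripling the channel and unit counts raises the parameter count roughly ninefold compared with the single-channel variant."
   ]
  },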
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     }
    },
    "colab_type": "code",
    "id": "_lza9t_uj5w1"
   },
   "outputs": [],
   "source": [
    "torch.manual_seed(RANDOM_SEED)\n",
    "\n",
    "model = LeNet5(NUM_CLASSES, GRAYSCALE)\n",
    "model.to(DEVICE)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "RAodboScj5w6"
   },
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 1547
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 2384585,
     "status": "ok",
     "timestamp": 1524976888520,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "Dzh3ROmRj5w7",
    "outputId": "5f8fd8c9-b076-403a-b0b7-fd2d498b48d7"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 001/010 | Batch 0000/0391 | Cost: 2.3071\n",
      "Epoch: 001/010 | Batch 0050/0391 | Cost: 1.8403\n",
      "Epoch: 001/010 | Batch 0100/0391 | Cost: 1.6112\n",
      "Epoch: 001/010 | Batch 0150/0391 | Cost: 1.5836\n",
      "Epoch: 001/010 | Batch 0200/0391 | Cost: 1.4769\n",
      "Epoch: 001/010 | Batch 0250/0391 | Cost: 1.3934\n",
      "Epoch: 001/010 | Batch 0300/0391 | Cost: 1.3729\n",
      "Epoch: 001/010 | Batch 0350/0391 | Cost: 1.2838\n",
      "Epoch: 001/010 | Train: 52.630%\n",
      "Time elapsed: 0.06 min\n",
      "Epoch: 002/010 | Batch 0000/0391 | Cost: 1.4668\n",
      "Epoch: 002/010 | Batch 0050/0391 | Cost: 1.2844\n",
      "Epoch: 002/010 | Batch 0100/0391 | Cost: 1.2798\n",
      "Epoch: 002/010 | Batch 0150/0391 | Cost: 1.3208\n",
      "Epoch: 002/010 | Batch 0200/0391 | Cost: 1.2563\n",
      "Epoch: 002/010 | Batch 0250/0391 | Cost: 1.3398\n",
      "Epoch: 002/010 | Batch 0300/0391 | Cost: 1.2109\n",
      "Epoch: 002/010 | Batch 0350/0391 | Cost: 1.3151\n",
      "Epoch: 002/010 | Train: 56.752%\n",
      "Time elapsed: 0.12 min\n",
      "Epoch: 003/010 | Batch 0000/0391 | Cost: 1.2106\n",
      "Epoch: 003/010 | Batch 0050/0391 | Cost: 1.1814\n",
      "Epoch: 003/010 | Batch 0100/0391 | Cost: 1.3497\n",
      "Epoch: 003/010 | Batch 0150/0391 | Cost: 1.0703\n",
      "Epoch: 003/010 | Batch 0200/0391 | Cost: 1.1444\n",
      "Epoch: 003/010 | Batch 0250/0391 | Cost: 0.8451\n",
      "Epoch: 003/010 | Batch 0300/0391 | Cost: 1.2356\n",
      "Epoch: 003/010 | Batch 0350/0391 | Cost: 1.2627\n",
      "Epoch: 003/010 | Train: 62.042%\n",
      "Time elapsed: 0.18 min\n",
      "Epoch: 004/010 | Batch 0000/0391 | Cost: 1.2457\n",
      "Epoch: 004/010 | Batch 0050/0391 | Cost: 1.1855\n",
      "Epoch: 004/010 | Batch 0100/0391 | Cost: 1.1870\n",
      "Epoch: 004/010 | Batch 0150/0391 | Cost: 1.1477\n",
      "Epoch: 004/010 | Batch 0200/0391 | Cost: 0.9527\n",
      "Epoch: 004/010 | Batch 0250/0391 | Cost: 1.3219\n",
      "Epoch: 004/010 | Batch 0300/0391 | Cost: 0.9374\n",
      "Epoch: 004/010 | Batch 0350/0391 | Cost: 1.0800\n",
      "Epoch: 004/010 | Train: 62.358%\n",
      "Time elapsed: 0.23 min\n",
      "Epoch: 005/010 | Batch 0000/0391 | Cost: 1.1676\n",
      "Epoch: 005/010 | Batch 0050/0391 | Cost: 1.0142\n",
      "Epoch: 005/010 | Batch 0100/0391 | Cost: 1.1620\n",
      "Epoch: 005/010 | Batch 0150/0391 | Cost: 1.0447\n",
      "Epoch: 005/010 | Batch 0200/0391 | Cost: 1.0203\n",
      "Epoch: 005/010 | Batch 0250/0391 | Cost: 1.1567\n",
      "Epoch: 005/010 | Batch 0300/0391 | Cost: 1.2270\n",
      "Epoch: 005/010 | Batch 0350/0391 | Cost: 1.2121\n",
      "Epoch: 005/010 | Train: 65.738%\n",
      "Time elapsed: 0.29 min\n",
      "Epoch: 006/010 | Batch 0000/0391 | Cost: 0.8958\n",
      "Epoch: 006/010 | Batch 0050/0391 | Cost: 0.8708\n",
      "Epoch: 006/010 | Batch 0100/0391 | Cost: 0.8954\n",
      "Epoch: 006/010 | Batch 0150/0391 | Cost: 1.0416\n",
      "Epoch: 006/010 | Batch 0200/0391 | Cost: 0.9596\n",
      "Epoch: 006/010 | Batch 0250/0391 | Cost: 1.1908\n",
      "Epoch: 006/010 | Batch 0300/0391 | Cost: 1.0528\n",
      "Epoch: 006/010 | Batch 0350/0391 | Cost: 1.1561\n",
      "Epoch: 006/010 | Train: 65.396%\n",
      "Time elapsed: 0.35 min\n",
      "Epoch: 007/010 | Batch 0000/0391 | Cost: 1.0105\n",
      "Epoch: 007/010 | Batch 0050/0391 | Cost: 1.0058\n",
      "Epoch: 007/010 | Batch 0100/0391 | Cost: 1.0195\n",
      "Epoch: 007/010 | Batch 0150/0391 | Cost: 0.9968\n",
      "Epoch: 007/010 | Batch 0200/0391 | Cost: 0.9785\n",
      "Epoch: 007/010 | Batch 0250/0391 | Cost: 0.9639\n",
      "Epoch: 007/010 | Batch 0300/0391 | Cost: 0.9576\n",
      "Epoch: 007/010 | Batch 0350/0391 | Cost: 1.1266\n",
      "Epoch: 007/010 | Train: 67.214%\n",
      "Time elapsed: 0.40 min\n",
      "Epoch: 008/010 | Batch 0000/0391 | Cost: 0.8093\n",
      "Epoch: 008/010 | Batch 0050/0391 | Cost: 0.9909\n",
      "Epoch: 008/010 | Batch 0100/0391 | Cost: 0.9171\n",
      "Epoch: 008/010 | Batch 0150/0391 | Cost: 1.0127\n",
      "Epoch: 008/010 | Batch 0200/0391 | Cost: 0.8954\n",
      "Epoch: 008/010 | Batch 0250/0391 | Cost: 1.0231\n",
      "Epoch: 008/010 | Batch 0300/0391 | Cost: 0.8512\n",
      "Epoch: 008/010 | Batch 0350/0391 | Cost: 1.2245\n",
      "Epoch: 008/010 | Train: 66.308%\n",
      "Time elapsed: 0.46 min\n",
      "Epoch: 009/010 | Batch 0000/0391 | Cost: 0.8735\n",
      "Epoch: 009/010 | Batch 0050/0391 | Cost: 0.8906\n",
      "Epoch: 009/010 | Batch 0100/0391 | Cost: 0.8600\n",
      "Epoch: 009/010 | Batch 0150/0391 | Cost: 1.0259\n",
      "Epoch: 009/010 | Batch 0200/0391 | Cost: 0.9805\n",
      "Epoch: 009/010 | Batch 0250/0391 | Cost: 0.9791\n",
      "Epoch: 009/010 | Batch 0300/0391 | Cost: 0.9160\n",
      "Epoch: 009/010 | Batch 0350/0391 | Cost: 0.9888\n",
      "Epoch: 009/010 | Train: 69.292%\n",
      "Time elapsed: 0.52 min\n",
      "Epoch: 010/010 | Batch 0000/0391 | Cost: 0.8405\n",
      "Epoch: 010/010 | Batch 0050/0391 | Cost: 0.9734\n",
      "Epoch: 010/010 | Batch 0100/0391 | Cost: 1.0093\n",
      "Epoch: 010/010 | Batch 0150/0391 | Cost: 0.8113\n",
      "Epoch: 010/010 | Batch 0200/0391 | Cost: 0.9860\n",
      "Epoch: 010/010 | Batch 0250/0391 | Cost: 0.9763\n",
      "Epoch: 010/010 | Batch 0300/0391 | Cost: 0.9767\n",
      "Epoch: 010/010 | Batch 0350/0391 | Cost: 0.9293\n",
      "Epoch: 010/010 | Train: 69.502%\n",
      "Time elapsed: 0.58 min\n",
      "Total Training Time: 0.58 min\n"
     ]
    }
   ],
   "source": [
    "def compute_accuracy(model, data_loader, device):\n",
    "    correct_pred, num_examples = 0, 0\n",
    "    for i, (features, targets) in enumerate(data_loader):\n",
    "            \n",
    "        features = features.to(device)\n",
    "        targets = targets.to(device)\n",
    "\n",
    "        logits, probas = model(features)\n",
    "        _, predicted_labels = torch.max(probas, 1)\n",
    "        num_examples += targets.size(0)\n",
    "        correct_pred += (predicted_labels == targets).sum()\n",
    "    return correct_pred.float()/num_examples * 100\n",
    "    \n",
    "\n",
    "start_time = time.time()\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    \n",
    "    model.train()\n",
    "    for batch_idx, (features, targets) in enumerate(train_loader):\n",
    "        \n",
    "        features = features.to(DEVICE)\n",
    "        targets = targets.to(DEVICE)\n",
    "            \n",
    "        ### FORWARD AND BACK PROP\n",
    "        logits, probas = model(features)\n",
    "        cost = F.cross_entropy(logits, targets)\n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "        cost.backward()\n",
    "        \n",
    "        ### UPDATE MODEL PARAMETERS\n",
    "        optimizer.step()\n",
    "        \n",
    "        ### LOGGING\n",
    "        if not batch_idx % 50:\n",
    "            print ('Epoch: %03d/%03d | Batch %04d/%04d | Cost: %.4f' \n",
    "                   %(epoch+1, NUM_EPOCHS, batch_idx, \n",
    "                     len(train_loader), cost))\n",
    "\n",
    "        \n",
    "\n",
    "    model.eval()\n",
    "    with torch.set_grad_enabled(False): # save memory during inference\n",
    "        print('Epoch: %03d/%03d | Train: %.3f%%' % (\n",
    "              epoch+1, NUM_EPOCHS, \n",
    "              compute_accuracy(model, train_loader, device=DEVICE)))\n",
    "        \n",
    "    print('Time elapsed: %.2f min' % ((time.time() - start_time)/60))\n",
    "    \n",
    "print('Total Training Time: %.2f min' % ((time.time() - start_time)/60))"
   ]
  },
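  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check on the training log, the first cost value can be compared with the cross-entropy of a classifier that assigns equal probability to all classes. The symbols are introduced here for illustration only: $z$ denotes the logits of one example, $y$ its true class index, and $K = 10$ the number of classes.\n",
    "\n",
    "$$\\mathcal{L}(z, y) = -\\log \\frac{e^{z_y}}{\\sum_{j=1}^{K} e^{z_j}}$$\n",
    "\n",
    "$$z_1 = \\dots = z_K \\quad \\Rightarrow \\quad \\mathcal{L} = -\\log \\frac{1}{K} = \\log 10 \\approx 2.3026$$\n",
    "\n",
    "The first logged cost of 2.3071 is close to this value, which is what one would expect from a freshly initialized network that has not yet learned to separate the classes."
   ]
  },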
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "paaeEQHQj5xC"
   },
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 6514,
     "status": "ok",
     "timestamp": 1524976895054,
     "user": {
      "displayName": "Sebastian Raschka",
      "photoUrl": "//lh6.googleusercontent.com/-cxK6yOSQ6uE/AAAAAAAAAAI/AAAAAAAAIfw/P9ar_CHsKOQ/s50-c-k-no/photo.jpg",
      "userId": "118404394130788869227"
     },
     "user_tz": 240
    },
    "id": "gzQMWKq5j5xE",
    "outputId": "de7dc005-5eeb-4177-9f9f-d9b5d1358db9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test accuracy: 61.70%\n"
     ]
    }
   ],
   "source": [
    "with torch.set_grad_enabled(False): # save memory during inference\n",
    "    print('Test accuracy: %.2f%%' % (compute_accuracy(model, test_loader, device=DEVICE)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "numpy       1.16.4\n",
      "pandas      0.24.2\n",
      "matplotlib  3.1.0\n",
      "torchvision 0.4.0a0+6b959ee\n",
      "PIL.Image   6.0.0\n",
      "torch       1.2.0\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%watermark -iv"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "default_view": {},
   "name": "convnet-vgg16.ipynb",
   "provenance": [],
   "version": "0.3.2",
   "views": {}
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": true,
   "toc_position": {
    "height": "calc(100% - 180px)",
    "left": "10px",
    "top": "150px",
    "width": "371px"
   },
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
